# encoding: utf-8


import keras
import matplotlib as mpl
mpl.use('Agg')  # 服务器使用matplotlib
import matplotlib.pyplot as plt


class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss/accuracy histories and saves them as plots.

    After ``on_train_begin`` the instance holds four dicts -- ``loss``,
    ``acc``, ``val_loss`` and ``val_acc`` -- each with a ``'batch'`` and an
    ``'epoch'`` list that receives one entry per finished batch/epoch.
    A metric absent from the Keras ``logs`` dict is recorded as ``None``.
    """

    def on_train_begin(self, logs=None):
        """Reset every recorded history at the start of training.

        Args:
            logs: Unused; accepted to match the Keras callback signature.
        """
        self.loss = {'batch': [], 'epoch': []}
        self.acc = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs=None):
        """Record the metrics reported for a finished batch.

        Args:
            batch: Index of the finished batch (unused).
            logs: Metric dict supplied by Keras; ``None`` is treated as empty.
        """
        self._append_metrics('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        """Record the metrics reported for a finished epoch.

        Args:
            batch: Index of the finished *epoch* (unused). The misleading name
                is kept so existing callers stay compatible.
            logs: Metric dict supplied by Keras; ``None`` is treated as empty.
        """
        self._append_metrics('epoch', logs)

    def loss_plot(self, loss_type, name):
        """Save a plot of the training loss (plus validation loss for epochs).

        Args:
            loss_type: ``'batch'`` or ``'epoch'``; any other value raises
                ``KeyError`` before a figure is created.
            name: Output destination passed to ``Figure.savefig``.
        """
        self._save_plot(loss_type, name, 'loss',
                        (self.loss[loss_type], 'g', 'train loss'),
                        (self.val_loss[loss_type], 'k', 'val loss'))

    def acc_plot(self, acc_type, name):
        """Save a plot of the training accuracy (plus validation accuracy for epochs).

        Args:
            acc_type: ``'batch'`` or ``'epoch'``; any other value raises
                ``KeyError`` before a figure is created.
            name: Output destination passed to ``Figure.savefig``.
        """
        self._save_plot(acc_type, name, 'acc',
                        (self.acc[acc_type], 'r', 'train acc'),
                        (self.val_acc[acc_type], 'b', 'val acc'))

    def _append_metrics(self, period, logs):
        """Append the current step's metrics from ``logs`` to the ``period`` lists."""
        logs = logs or {}
        self.loss[period].append(logs.get('loss'))
        # The accuracy key follows the metric name given to compile(); newer
        # Keras versions log 'accuracy' where older ones logged 'acc'. Accept
        # either so the accuracy history is not silently all None.
        self.acc[period].append(self._first_present(logs, 'acc', 'accuracy'))
        self.val_loss[period].append(logs.get('val_loss'))
        self.val_acc[period].append(
            self._first_present(logs, 'val_acc', 'val_accuracy'))

    @staticmethod
    def _first_present(logs, *keys):
        """Return the value of the first key in ``keys`` set (non-None) in ``logs``, else None."""
        for key in keys:
            value = logs.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _save_plot(period, name, ylabel, train, val):
        """Draw the training curve (and, for epochs, the validation curve) and save it.

        ``train`` and ``val`` are ``(values, fmt, label)`` tuples. The figure
        is always closed afterwards: pyplot keeps every open figure alive, so
        repeated plotting would otherwise leak memory.
        """
        values, fmt, label = train
        val_values, val_fmt, val_label = val
        iters = range(len(values))
        fig, ax = plt.subplots()
        try:
            ax.plot(iters, values, fmt, label=label)
            # Validation curves are drawn only for epoch plots; skip one whose
            # entries are all None (e.g. training ran without validation data).
            if period == 'epoch' and any(v is not None for v in val_values):
                ax.plot(iters, val_values, val_fmt, label=val_label)
            ax.grid(True)
            ax.set_xlabel(period)
            ax.set_ylabel(ylabel)
            ax.legend(loc='upper right')
            fig.savefig(name)
        finally:
            plt.close(fig)
